{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-30T15:26:10.620382Z",
     "start_time": "2020-03-30T15:26:10.615475Z"
    }
   },
   "outputs": [],
   "source": [
    "def warn(*args, **kwargs):\n",
    "    pass\n",
    "import warnings\n",
    "warnings.warn = warn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-30T15:26:11.013555Z",
     "start_time": "2020-03-30T15:26:10.974655Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "from time import time\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build Data Loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-30T15:26:13.290928Z",
     "start_time": "2020-03-30T15:26:13.272927Z"
    }
   },
   "outputs": [],
   "source": [
    "from keras.utils import Sequence\n",
    "from keras.utils import np_utils\n",
    "\n",
    "class DataGenerator(Sequence):\n",
    "    \"\"\"Data Generator inherited from keras.utils.Sequence\n",
    "    Args: \n",
    "        directory: the path of data set, and each sub-folder will be assigned to one class\n",
    "        batch_size: the number of data points in each batch\n",
    "        shuffle: whether to shuffle the data per epoch\n",
    "    Note:\n",
    "        If you want to load file with other data format, please fix the method of \"load_data\" as you want\n",
    "    \"\"\"\n",
    "    def __init__(self, directory, batch_size=1, shuffle=True, data_augmentation=True):\n",
    "        # Initialize the params\n",
    "        self.batch_size = batch_size\n",
    "        self.directory = directory\n",
    "        self.shuffle = shuffle\n",
    "        self.data_aug = data_augmentation\n",
    "        # Load all the save_path of files, and create a dictionary that save the pair of \"data:label\"\n",
    "        self.X_path, self.Y_dict = self.search_data() \n",
    "        # Print basic statistics information\n",
    "        self.print_stats()\n",
    "        return None\n",
    "        \n",
    "    def search_data(self):\n",
    "        X_path = []\n",
    "        Y_dict = {}\n",
    "        # list all kinds of sub-folders\n",
    "        self.dirs = sorted(os.listdir(self.directory))\n",
    "        one_hots = np_utils.to_categorical(range(len(self.dirs)))\n",
    "        for i,folder in enumerate(self.dirs):\n",
    "            folder_path = os.path.join(self.directory,folder)\n",
    "            for file in os.listdir(folder_path):\n",
    "                file_path = os.path.join(folder_path,file)\n",
    "                # append the each file path, and keep its label  \n",
    "                X_path.append(file_path)\n",
    "                Y_dict[file_path] = one_hots[i]\n",
    "        return X_path, Y_dict\n",
    "    \n",
    "    def print_stats(self):\n",
    "        # calculate basic information\n",
    "        self.n_files = len(self.X_path)\n",
    "        self.n_classes = len(self.dirs)\n",
    "        self.indexes = np.arange(len(self.X_path))\n",
    "        np.random.shuffle(self.indexes)\n",
    "        # Output states\n",
    "        print(\"Found {} files belonging to {} classes.\".format(self.n_files,self.n_classes))\n",
    "        for i,label in enumerate(self.dirs):\n",
    "            print('%10s : '%(label),i)\n",
    "        return None\n",
    "    \n",
    "    def __len__(self):\n",
    "        # calculate the iterations of each epoch\n",
    "        steps_per_epoch = np.ceil(len(self.X_path) / float(self.batch_size))\n",
    "        return int(steps_per_epoch)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Get the data of each batch\n",
    "        \"\"\"\n",
    "        # get the indexs of each batch\n",
    "        batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]\n",
    "        # using batch_indexs to get path of current batch\n",
    "        batch_path = [self.X_path[k] for k in batch_indexs]\n",
    "        # get batch data\n",
    "        batch_x, batch_y = self.data_generation(batch_path)\n",
    "        return batch_x, batch_y\n",
    "\n",
    "    def on_epoch_end(self):\n",
    "        # shuffle the data at each end of epoch\n",
    "        if self.shuffle == True:\n",
    "            np.random.shuffle(self.indexes)\n",
    "\n",
    "    def data_generation(self, batch_path):\n",
    "        # load data into memory, you can change the np.load to any method you want\n",
    "        batch_x = [self.load_data(x) for x in batch_path]\n",
    "        batch_y = [self.Y_dict[x] for x in batch_path]\n",
    "        # transfer the data format and take one-hot coding for labels\n",
    "        batch_x = np.array(batch_x)\n",
    "        batch_y = np.array(batch_y)\n",
    "        return batch_x, batch_y\n",
    "      \n",
    "    def normalize(self, data):\n",
    "        mean = np.mean(data)\n",
    "        std = np.std(data)\n",
    "        return (data-mean) / std\n",
    "    \n",
    "    def uniform_sampling(self, video, target_frames=64):\n",
    "        # get total frames of input video and calculate sampling interval \n",
    "        len_frames = int(len(video))\n",
    "        interval = int(np.ceil(len_frames/target_frames))\n",
    "        # init empty list for sampled video and \n",
    "        sampled_video = []\n",
    "        for i in range(0,len_frames,interval):\n",
    "            sampled_video.append(video[i])     \n",
    "        # calculate numer of padded frames and fix it \n",
    "        num_pad = target_frames - len(sampled_video)\n",
    "        padding = []\n",
    "        if num_pad>0:\n",
    "            for i in range(-num_pad,0):\n",
    "                try: \n",
    "                    padding.append(video[i])\n",
    "                except:\n",
    "                    padding.append(video[0])\n",
    "            sampled_video += padding     \n",
    "        # get sampled video\n",
    "        return np.array(sampled_video, dtype=np.float32)\n",
    "    \n",
    "    def load_data(self, path):\n",
    "        # in this project, just load the first 3 channels which containing RGB images\n",
    "        data = np.load(path, mmap_mode='r')[...,3:]\n",
    "        data = np.float32(data)\n",
    "        # sampling 64 frames uniformly from the entire video \n",
    "        data = self.uniform_sampling(video=data, target_frames=64)\n",
    "       # normalize the data into means=0, variance=1\n",
    "        data = self.normalize(data)\n",
    "        return data\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-30T15:26:13.780918Z",
     "start_time": "2020-03-30T15:26:13.775104Z"
    }
   },
   "outputs": [],
   "source": [
    "from base_model import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-30T15:26:18.217573Z",
     "start_time": "2020-03-30T15:26:14.260675Z"
    }
   },
   "outputs": [],
   "source": [
    "model = Inception_Inflated3d(include_top=True,\n",
    "                #weights='rgb_imagenet_and_kinetics',\n",
    "                input_tensor=None,\n",
    "                input_shape=(64,224,224,2),\n",
    "                dropout_prob=0.5,\n",
    "                endpoint_logit=False,\n",
    "                classes=2)\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set the GPUs and make it parallel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-30T15:26:26.358914Z",
     "start_time": "2020-03-30T15:26:18.225015Z"
    }
   },
   "outputs": [],
   "source": [
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2,3\"\n",
    "\n",
    "from keras.utils import multi_gpu_model   \n",
    "parallel_model = multi_gpu_model(model, gpus=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Compiling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-30T15:26:26.414633Z",
     "start_time": "2020-03-30T15:26:26.366909Z"
    }
   },
   "outputs": [],
   "source": [
    "from keras.optimizers import Adam, SGD\n",
    "\n",
    "sgd  = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)\n",
    "parallel_model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set Callbacks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Learning Rate scheduler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-30T15:26:26.425939Z",
     "start_time": "2020-03-30T15:26:26.421787Z"
    }
   },
   "outputs": [],
   "source": [
    "import keras.backend as K\n",
    "from keras.callbacks import LearningRateScheduler\n",
    "\n",
    "def scheduler(epoch):\n",
    "    if epoch % 10 == 0 and epoch != 0:\n",
    "        lr = K.get_value(parallel_model.optimizer.lr)\n",
    "        K.set_value(parallel_model.optimizer.lr, lr * 0.8)\n",
    "    return K.get_value(parallel_model.optimizer.lr)\n",
    "\n",
    "reduce_lr = LearningRateScheduler(scheduler)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Saving the best model and training logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-30T15:26:26.438782Z",
     "start_time": "2020-03-30T15:26:26.433545Z"
    }
   },
   "outputs": [],
   "source": [
    "from keras.callbacks import ModelCheckpoint, CSVLogger\n",
    "\n",
    "class ParallelModelCheckpoint(ModelCheckpoint):\n",
    "    def __init__(self,model,filepath, monitor='val_loss', verbose=0,\n",
    "                 save_best_only=False, save_weights_only=False,\n",
    "                 mode='auto', period=1):\n",
    "        self.single_model = model\n",
    "        super(ParallelModelCheckpoint,self).__init__(filepath, monitor, verbose,save_best_only, save_weights_only,mode, period)\n",
    "\n",
    "    def set_model(self, model):\n",
    "        super(ParallelModelCheckpoint,self).set_model(self.single_model)\n",
    "\n",
    "check_point = ParallelModelCheckpoint(parallel_model ,'Logs/I3D-flow.hd5')\n",
    "\n",
    "\n",
    "\n",
    "filename = 'Logs/I3D-flow_log.csv'\n",
    "csv_logger = CSVLogger(filename, separator=',', append=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-30T15:26:26.447750Z",
     "start_time": "2020-03-30T15:26:26.445246Z"
    }
   },
   "outputs": [],
   "source": [
    "callbacks_list = [check_point, csv_logger, reduce_lr]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- set essential params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-30T15:26:28.588483Z",
     "start_time": "2020-03-30T15:26:28.583544Z"
    }
   },
   "outputs": [],
   "source": [
    "num_epochs  = 50\n",
    "num_workers = 16\n",
    "batch_size  = 8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- init data generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-30T15:26:29.517318Z",
     "start_time": "2020-03-30T15:26:29.505793Z"
    }
   },
   "outputs": [],
   "source": [
    "dataset = 'RWF2000'\n",
    "train_generator = DataGenerator(directory='../Datasets/{}/train'.format(dataset), \n",
    "                                batch_size=batch_size,\n",
    "                                data_augmentation=True)\n",
    "\n",
    "val_generator = DataGenerator(directory='../Datasets/{}/val'.format(dataset),\n",
    "                              batch_size=batch_size, \n",
    "                              data_augmentation=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- start to train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-30T15:48:50.159577Z",
     "start_time": "2020-03-30T15:26:30.519341Z"
    }
   },
   "outputs": [],
   "source": [
    "hist = parallel_model.fit_generator(\n",
    "    generator=train_generator, \n",
    "    validation_data=val_generator,\n",
    "    callbacks=callbacks_list,\n",
    "    verbose=1, \n",
    "    epochs=num_epochs,\n",
    "    workers=num_workers ,\n",
    "    max_queue_size=4,\n",
    "    steps_per_epoch=len(train_generator),\n",
    "    validation_steps=len(val_generator))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
